
import os
import re
import sys
import base64
import openai
import subprocess
from pathlib import Path
from threading import Timer
import inference

# Model identifier handed through to the inference helpers by the pipeline below.
model = "GPT4o"

# Search hyper-parameters for the goal-diagram pipeline.
PARAMS = dict(
    temperature=0.6,        # sampling temperature passed to code generation
    num_samples=8,          # stop once this many verified diagrams have been kept
    max_attempts=60,        # hard cap on generation attempts across all loops
    execution_timeout=20,   # seconds a generated script may run before it is killed
    mem_capacity=5,         # how many earlier snippets are shown to the generator
    outer_loops=15,         # N: number of outer loops (feedback memory resets each one)
    inner_loops=4,          # M: attempts per outer loop (total attempts = N*M)
)

def read_file(path):
    """Return the entire contents of the text file at *path*, decoded as UTF-8."""
    return Path(path).read_text(encoding='utf-8')

def write_file(path, content):
    """Write *content* to *path* as UTF-8 text, creating parent directories first.

    A bare filename (no directory component) is written into the current
    working directory; ``os.makedirs("")`` would raise, so that case skips
    directory creation.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)

def run_matplotlib_code(code_str, output_png_path, timeout=None):
    """Run generated matplotlib code in a child Python process and check its PNG.

    The code is written to ``matplot_code.py`` in the directory of
    *output_png_path* and executed with the current interpreter
    (``sys.executable``), so it sees the same installed packages as this
    script.

    Args:
        code_str: Python source that is expected to save a figure to
            *output_png_path*.
        output_png_path: Path the generated code should write its PNG to.
        timeout: Seconds before the child process is killed. Defaults to
            ``PARAMS["execution_timeout"]``.

    Returns:
        None on success; otherwise a human-readable error string (non-zero
        exit status, timeout, or a missing/empty PNG).
    """
    if timeout is None:
        timeout = PARAMS["execution_timeout"]

    # A bare filename has no directory part; run it from the current directory.
    out_dir = os.path.dirname(output_png_path) or "."
    os.makedirs(out_dir, exist_ok=True)
    code_file_path = os.path.join(out_dir, "matplot_code.py")
    with open(code_file_path, 'w', encoding='utf-8') as f:
        f.write(code_str)

    # Remove any PNG left over from an earlier run, so the check below only
    # accepts a file produced by *this* execution.
    if os.path.isfile(output_png_path):
        os.remove(output_png_path)

    try:
        # subprocess.run kills the child and reaps it before re-raising on timeout.
        proc = subprocess.run(
            [sys.executable, code_file_path],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return f"Execution error: timed out after {timeout} seconds"

    if proc.returncode != 0:
        # Generated code may emit non-UTF-8 bytes; never let decoding crash the pipeline.
        return f"Execution error: {proc.stderr.decode('utf-8', errors='replace')}"
    if not os.path.isfile(output_png_path) or os.path.getsize(output_png_path) == 0:
        return "No PNG generated or PNG file is empty."
    return None

def verify_goal_diagram(
    problem_description,
    initial_state,
    goal_state,
    initial_diagram_path,
    initial_diagram_reasoning,
    goal_diagram_path,
    model
):
    if not os.path.isfile(initial_diagram_path):
        return False, "PNG file not found or missing."
    if not os.path.isfile(goal_diagram_path):
        return False, "PNG file not found or missing."

    prompt_parts = [
    f"""
    Your task is to verify a suggested diagram for a goal state against these requirements:
    1) Problem description:
    {problem_description}
    2) This is the initial state of the problem
    {initial_state}
    3) The goal state is:
    {goal_state}
    We have the following diagram for the initial state:""",
    {"image_path": initial_diagram_path},
    "Also here's the reasoning behind how objects are visualized in the intial state, along with am explanation of what the meaning behind each shape, color, and different sizes and locations is in the diagram.",
    {initial_diagram_reasoning},
    "Now we have the suggested diagram below for the goal state:",
    {"image_path": goal_diagram_path},
    f"""
    Your task is to verify the suggested goal diagram. Check the following carefully, think step by step and explain.
    
    - For objects present in the initial state that have no constraints in the goal, does the diagram mention or handle them properly? (e.g. by placing them in the default location if it clearly makes sense, or using a text note if we can make no conclusion about their final location)?
    - Does the diagram visualize all of the objects and constraints described in the goal state? (For this you have to make sure no 2 objects overlap)
    - Is the legend visualized to the side of the diagram such that it does not overlap any part of the diagram?
    - Does the reasoning provided for the initial state's diagram match what is shown in the diagram? (I.e. is the status of objects accurately visualized with respect to the legend and the diragram reasoning?)
    - Does each object in the diagram have a text lable inside it? Is the text clearly readable?
    - Does the diagram appear physically plausible (e.g., no floating objects if not approperiate, no misaliged cells in a grid, etc)?
    - Is the status of the objects acurately visualized with respect to the legend and the reasoning?
    Finally, provide "yes" if this diagram is correct and satisfies the requirement above, or "no" otherwise, with the following format:
    ```yes_no
    <yes or no>
    ```
    If your final answer is "no", provide a short phrase describing the problem with the diagram in this format:
    ```error
    <error description>
    ```
    """
    ]
    response = inference.get_model_response(prompt_parts, model)
    validity = inference.extract_content(response, "yes_no")
    error_description = inference.extract_content(response, "error")
    return validity, error_description

def generate_goal_diagram_code(
    problem_description,
    initial_state,
    goal_state,
    initial_code,
    initial_diagram_reasoning,
    initial_diagram_path,
    prev_section,
    prev_error,
    output_path,
    model,
    temp
):
    """Ask the model for matplotlib code that draws the goal-state diagram.

    The prompt shows the problem, both states, the initial-state diagram code,
    its reasoning and its rendered PNG, followed by drawing rules. When they
    are non-empty, it also includes the previous error and earlier snippets.

    Returns:
        The content of the response's ``code`` block, as returned by
        ``inference.extract_content``. If the initial diagram PNG is missing,
        prints a message and returns ``(False, "PNG file not found or missing.")``.
    """
    missing_msg = "PNG file not found or missing."
    if not os.path.isfile(initial_diagram_path):
        print(missing_msg)
        return False, missing_msg

    # NOTE: continuation lines inside these triple-quoted strings are prompt
    # text; their leading four spaces are sent to the model verbatim.
    context_part = f"""
    You are generating Python matplotlib code that draws a diagram for the goal state of the following problem:
    Problem Description:
    {problem_description}
    Initial_state:
    {initial_state}
    Goal state:
    {goal_state}
    We have drawn a diagram for the initial state using the code below. Use it as a reference:
    {initial_code},
    Also here's the reasoning behind the diagram code, along with explanation of what the meaning behind each shape, color, and different sizes and locations is in the diagram.
    {initial_diagram_reasoning}"""

    rules_part = f"""IMPORTANT:
    1) If no constraints are given for an object in a goal state but the object is present in the initial state, visualize the object in a default location or default status if it exists. Important: if no clear decision can be made about the final status or location of the object, draw the object at the bottom of the page, below the diagram with status the object has no constraints.
    2) Ensure no two objects overlap in the diagram. Also ensure the legend does not overlap the objects in the figure.
    3) Place text labels and status for each object inside their shape. Make sure any text is clear and readable.
    4) Reuse consistent shapes/colors for the same object types from the initial state code and the reasoning provided for how to encode different objects. (i.e. the meaning behind each shape (e.g. what a rectangle represents vs. what a circle represents), colors (e.g. meaning behind usage of red vs. green), sizes should be consistent across the 2 diagrams)
    5) You should not add any additional constraints to the goal state, if no constraints are given about an object, visualize it default status or draw it at the bottom of the page with status no constraints.
    6) The final line of the code must save the figure to:
    {output_path}"""

    # Optional feedback slots stay in the prompt list as "" when unused.
    error_part = ""
    if prev_error:
        error_part = f"""Avoid repeating the previously encountered error:
    {prev_error}"""

    history_part = ""
    if prev_section:
        history_part = f"""Below are previously generated code snippets for drawing the goal diagram.
    {prev_section}
    Make sure you produce a code that is different from the codes above:
    """

    format_part = """Provide your code in the format:
    ```code
    <python code>
    ```
    Make sure to include your reasoning above the code block.
    """

    prompt_text = [
        context_part,
        "The code above generated the following diagram for the initial state:",
        {"image_path": initial_diagram_path},
        rules_part,
        error_part,
        history_part,
        format_part,
    ]
    response = inference.get_model_response(prompt_text, model, temp)
    return inference.extract_content(response, "code", remove_new_lines=False)

def rank_goal_diagram_images(
    problem_description,
    initial_state,
    goal_state,
    initial_diagram_path,
    diagram_png_paths,
    model
):
    if not os.path.isfile(initial_diagram_path):
        return False, "PNG file not found or missing."

    prompt_parts = [
    f"""
    We have the following problem:
    {problem_description}
    Initial state:
    {initial_state}
    Goal state:
    {goal_state}""",
    "This is a visualization of the initial state:",
    {"image_path": initial_diagram_path},
    f"""
    We have a few candidate visualizations for the goal state. Your goal is to rank the goal state visualizations, from best to worst, based on how accurate, clear, and intuitive they are.
    You must consider the following:
    - Which diagram describes the scene most intuitively and accurately?
    - Whether the constraints in the goal state are accurately visualized.
    - All the objects in the goal state must be clearly drawn with no 2 objects overlapping.
    - For objects present in the initial state that have no constraints in the goal, the diagrams that handle thses objects by visualizaing them in a default position/status, if it makes sense and there's a plausible default, otherwise if no plausible default status for such objects is deductible, diagrams must use a text note describing that the object has no constraint at the bottom of the diagram.
    - There should be some information in the goal diagram about any objects that was in the initial state.
    - The meaning behind shapes/colors/legend in the goal state is consistent with that of the initial state.
    - Whether the diagram and the placement of the objects is physically plausible.
    - Clarity of all the text in the diagram, with minimal overlaps and good readability. Also the textual label and status for each object should be placed inside its shape.
    """
    ]
    for idx, path in enumerate(diagram_png_paths):
        prompt_parts.append(f"Diagram {idx+1} (index {idx}):")
        prompt_parts.append({"image_path": path})
    
    prompt_parts.append("""
    The candidate diagrams are attached above.
    After reviewing all diagrams, provide a final ranking in the format:
    ```ranking
    <diagram indices in order from best to worst, e.g. 2,0,1>
    ```
    Above the code block, iterate through each diagram and write statement about its weaknesses and strengths in the original order given. Then rank the diagrams. Think step by step about the ranking and explain.
    """)

    response = inference.get_model_response(prompt_parts, model)
    reasoning = response
    ranking = inference.extract_content(response, "ranking")
    if not ranking:
        return diagram_png_paths[::-1], reasoning

    ranked_ids = [int(id.strip()) for id in ranking if id.strip().isdigit()]
    if all(1 <= i <= len(diagram_png_paths) for i in ranked_ids) and len(ranked_ids) == len(diagram_png_paths):
        new_order = [diagram_png_paths[i-1] for i in ranked_ids]
        return new_order, reasoning
    else:
        return diagram_png_paths, reasoning

def _append_comment(path, header, text):
    """Append *header* and *text* to the attempt file at *path* as Python comments.

    Every line of *text* is prefixed with ``# ``, so multi-line messages (such
    as a captured traceback) keep the saved attempt file syntactically valid.
    """
    lines = str(text).splitlines() or [""]
    commented = "\n".join(f"# {line}" for line in lines)
    with open(path, 'a', encoding='utf-8') as f:
        f.write(f"\n# {header}:\n{commented}\n")


def _collect_valid_goal_diagrams(
    problem_description,
    initial_state,
    goal_state,
    initial_diagram_code,
    initial_diagram_reasoning,
    initial_diagram_path,
    attempts_dir,
):
    """Run the N x M generate -> execute -> verify search for goal diagrams.

    Each outer loop starts with empty feedback memory and runs up to
    PARAMS["inner_loops"] attempts; only its last two verified candidates are
    kept. The search stops after PARAMS["num_samples"] kept candidates or
    PARAMS["max_attempts"] attempts, whichever comes first.

    Returns:
        (codes, png_paths): parallel lists of verified candidates.
    """
    valid_codes = []
    valid_pngs = []
    attempt_count = 0
    mem_capacity = PARAMS["mem_capacity"]

    for _outer_idx in range(PARAMS["outer_loops"]):
        # Feedback memory (earlier snippets, last error) is reset per outer loop.
        prev_solutions = []
        prev_error = None
        outer_codes = []
        outer_pngs = []

        for _inner_idx in range(PARAMS["inner_loops"]):
            if len(valid_codes) >= PARAMS["num_samples"]:
                break
            if attempt_count >= PARAMS["max_attempts"]:
                break
            attempt_count += 1

            # Show the generator the most recent verified snippets (at most
            # mem_capacity of them) so that new attempts differ from them.
            recent = prev_solutions[-mem_capacity:] if mem_capacity > 0 else []
            prev_section = "\n".join(
                f"-- A previously generated code snippet:\n{code}\n" for code in recent
            )

            output_png_path = os.path.join(attempts_dir, f"goal_diagram_attempt_{attempt_count}.png")
            new_code = generate_goal_diagram_code(
                problem_description,
                initial_state,
                goal_state,
                initial_diagram_code,
                initial_diagram_reasoning,
                initial_diagram_path,
                prev_section,
                prev_error,
                output_png_path,
                model,
                temp=PARAMS["temperature"]
            )
            # The generator returns a (False, message) tuple when its inputs are missing.
            if not new_code or isinstance(new_code, tuple):
                print(f"[ERROR] Attempt {attempt_count} returned invalid code or error.")
                continue

            print(f"[INFO] Generated code attempt #{attempt_count} for goal diagram.")
            code_file_path = os.path.join(attempts_dir, f"goal_attempt_{attempt_count}.py")
            write_file(code_file_path, new_code)

            error = run_matplotlib_code(new_code, output_png_path)
            if error:
                prev_error = f"Attempt {attempt_count} execution error: {error}"
                _append_comment(code_file_path, "EXECUTION ERROR", error)
                print(f"[ERROR] Attempt {attempt_count} failed to produce PNG. {error}")
                continue

            is_valid, err_msg = verify_goal_diagram(
                problem_description,
                initial_state,
                goal_state,
                initial_diagram_path,
                initial_diagram_reasoning,
                output_png_path,
                model
            )
            if not is_valid:
                prev_error = f"Verification failed for attempt {attempt_count}: {err_msg}"
                _append_comment(code_file_path, "VERIFICATION FAILED", err_msg)
                print(f"[ERROR] Attempt {attempt_count} verification failed: {err_msg}")
                continue

            prev_error = None
            print(f"[INFO] Attempt {attempt_count} verified successfully!")
            outer_codes.append(new_code)
            outer_pngs.append(output_png_path)
            prev_solutions.append(new_code)

        # Keep only the last two verified candidates from this outer loop.
        if outer_codes:
            valid_codes.extend(outer_codes[-2:])
            valid_pngs.extend(outer_pngs[-2:])
            if len(valid_codes) >= PARAMS["num_samples"]:
                break

        if attempt_count >= PARAMS["max_attempts"]:
            break

    return valid_codes, valid_pngs


def get_1shot_diagram_code_goal(domain_name):
    """Generate, verify and rank goal-state diagram code for *domain_name*.

    Inputs are read from ``<domain_name>/``: ``<domain_name>_domain.txt``,
    ``initial_state.txt``, ``goal_state.txt`` and, under
    ``one_shot/ini_diagram_code/``, ``best_candidate_code.py``,
    ``best_candidate_reasoning.txt`` and ``best_diagram.png``. Every attempt
    is saved under ``one_shot/goal_diagram_code/attempts/``. The top-ranked
    candidate is written to ``one_shot/goal_diagram_code/best_candidate_code.py``
    and ``best_diagram.png``.

    Returns:
        (True, None) on success, or (False, message) when an input file is
        missing or no candidate passed verification.
    """
    base_dir = domain_name
    ini_dir = os.path.join(base_dir, "one_shot", "ini_diagram_code")

    # (path, message printed when it is missing), checked in this order.
    required_inputs = [
        (os.path.join(base_dir, f"{domain_name}_domain.txt"), "No problem description found."),
        (os.path.join(base_dir, "initial_state.txt"), "No initial state found."),
        (os.path.join(base_dir, "goal_state.txt"), "No goal state found."),
        (os.path.join(ini_dir, "best_candidate_code.py"), "No initial state code found."),
        (os.path.join(ini_dir, "best_candidate_reasoning.txt"), "No initial state diagram reasoning found."),
    ]
    texts = []
    for path, missing_msg in required_inputs:
        if not os.path.isfile(path):
            print(missing_msg)
            return False, missing_msg
        texts.append(read_file(path))
    (problem_description, initial_state, goal_state,
     initial_diagram_code, initial_diagram_reasoning) = texts

    # Without the initial diagram every generation attempt fails immediately,
    # so detect it once instead of burning the whole attempt budget.
    initial_diagram_path = os.path.join(ini_dir, "best_diagram.png")
    if not os.path.isfile(initial_diagram_path):
        missing_msg = "No initial state diagram found."
        print(missing_msg)
        return False, missing_msg

    base_task_dir = os.path.join(base_dir, "one_shot", "goal_diagram_code")
    attempts_dir = os.path.join(base_task_dir, "attempts")
    os.makedirs(attempts_dir, exist_ok=True)

    valid_codes, valid_pngs = _collect_valid_goal_diagrams(
        problem_description,
        initial_state,
        goal_state,
        initial_diagram_code,
        initial_diagram_reasoning,
        initial_diagram_path,
        attempts_dir,
    )
    if not valid_codes:
        error = "[ERROR] No valid goal diagram codes found; pipeline ended with no successes."
        print(error)
        return False, error

    print("[INFO] Ranking the valid goal diagrams now...")
    ranked_pngs, _reasoning = rank_goal_diagram_images(
        problem_description,
        initial_state,
        goal_state,
        initial_diagram_path,
        valid_pngs,
        model
    )
    # The ranker returns (False, message) instead of a list when the initial
    # diagram is missing; fall back to generation order in that case.
    if not isinstance(ranked_pngs, list) or not ranked_pngs:
        ranked_pngs = valid_pngs

    # The best diagram is the first in the ranked order.
    best_png = ranked_pngs[0]
    best_code = valid_codes[valid_pngs.index(best_png)]

    best_code_file = os.path.join(base_task_dir, "best_candidate_code.py")
    write_file(best_code_file, best_code)

    best_diagram_file = os.path.join(base_task_dir, "best_diagram.png")
    with open(best_png, 'rb') as src, open(best_diagram_file, 'wb') as dst:
        dst.write(src.read())

    print("[SUCCESS] Found a valid goal diagram code solution.")
    print(f"[INFO] Best code saved at: {best_code_file}")
    print(f"[INFO] Best diagram saved at: {best_diagram_file}")
    return True, None

def main():
    """CLI entry point: ``python one_shot_goal_diagram_code.py <domain_name>``.

    Exits with status 1 on bad usage, or when the pipeline reports failure, so
    that shell pipelines can detect a failed run.
    """
    if len(sys.argv) < 2:
        print("Usage: python one_shot_goal_diagram_code.py <domain_name>")
        sys.exit(1)
    domain_name = sys.argv[1]
    print("Started on getting verified and ranked diagram code for the goal state")
    result = get_1shot_diagram_code_goal(domain_name)
    # The pipeline may return None or a (success, error) tuple; both falsy forms mean failure.
    if not result or not result[0]:
        sys.exit(1)


if __name__ == "__main__":
    main()